How does the k-Means algorithm work?

Source file: https://shirinsplayground.netlify.app/2021/03/kmeans_101/

  1. Pick a number of clusters, k
  2. Create k random points and call each of these the center of a cluster
  3. For each point in your dataset, find the closest cluster center and assign the point to that cluster
  4. Once each point has been assigned to a cluster, calculate the new center of that cluster

Repeat steps 3 and 4 until you reach a stage when no points need to be reassigned.

Stop. You have found your k clusters and their centers!
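
The steps above can be sketched in a few lines of base R. This is a minimal illustration of the algorithm (the function name `simple_kmeans` is mine), not the optimized implementation behind `stats::kmeans()`:

```r
# Minimal k-means sketch (illustration only; use stats::kmeans() in practice)
simple_kmeans <- function(X, k, max_iter = 100) {
  X <- as.matrix(X)
  # Step 2: pick k random data points as the initial cluster centers
  centers <- X[sample(nrow(X), k), , drop = FALSE]
  assignment <- rep(0L, nrow(X))
  for (i in seq_len(max_iter)) {
    # Step 3: assign each point to its closest center (squared Euclidean distance)
    d <- sapply(seq_len(k), function(j) {
      rowSums((X - matrix(centers[j, ], nrow(X), ncol(X), byrow = TRUE))^2)
    })
    new_assignment <- max.col(-d)  # column index of the smallest distance per row
    # Stop once no points need to be reassigned
    if (all(new_assignment == assignment)) break
    assignment <- new_assignment
    # Step 4: recompute each center as the mean of its assigned points
    for (j in seq_len(k)) {
      members <- assignment == j
      if (any(members)) centers[j, ] <- colMeans(X[members, , drop = FALSE])
    }
  }
  list(cluster = assignment, centers = centers)
}
```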

If you want to learn more about k-Means, I would recommend this post on Medium, though be aware that its example code is all written in Python. If you are brave and want to go deep into k-Means theory, take a look at the Wikipedia page. Or, if you would like to see one application of k-Means in R, see this blog's post about using k-Means to assist in image classification with Keras. For a detailed illustration of how to implement k-Means in R, along with answers to some common questions, keep reading below.

Functions

Setting up two functions:

# Define two functions for transforming a distribution of values
#  into the standard normal distribution (bell curve with mean = 0
#  and standard deviation (sd) = 1). More on this later.
normalize_values <- function(x, mean, sd) {
  (x - mean) / sd
}

unnormalize_values <- function(x, mean, sd) {
  (x * sd) + mean
}

set.seed(2021) # So you can reproduce this example
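
As a quick sanity check, the two helpers invert each other (redefined here so the snippet stands alone):

```r
# Round-trip: unnormalize_values() undoes normalize_values()
normalize_values   <- function(x, mean, sd) (x - mean) / sd
unnormalize_values <- function(x, mean, sd) (x * sd) + mean

x <- c(40, 247, 680)                       # e.g. a few depths in km
z <- normalize_values(x, mean(x), sd(x))   # standardized values
stopifnot(isTRUE(all.equal(unnormalize_values(z, mean(x), sd(x)), x)))
```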

Loading the Data Sets

The data for this example comes from one of R's built-in datasets, quakes. It is a data.frame with 1000 rows and five columns describing earthquakes near Fiji since 1964. The columns are latitude (degrees), longitude (degrees), depth (km), magnitude (Richter scale), and the number of stations reporting the quake. The only pre-processing we will do now is to remove the stations column and convert the data.frame to a tibble.

quakes_raw <- quakes %>% 
  dplyr::select(-stations) %>% 
  dplyr::as_tibble()

summary(quakes_raw)
##       lat              long           depth            mag      
##  Min.   :-38.59   Min.   :165.7   Min.   : 40.0   Min.   :4.00  
##  1st Qu.:-23.47   1st Qu.:179.6   1st Qu.: 99.0   1st Qu.:4.30  
##  Median :-20.30   Median :181.4   Median :247.0   Median :4.60  
##  Mean   :-20.64   Mean   :179.5   Mean   :311.4   Mean   :4.62  
##  3rd Qu.:-17.64   3rd Qu.:183.2   3rd Qu.:543.0   3rd Qu.:4.90  
##  Max.   :-10.72   Max.   :188.1   Max.   :680.0   Max.   :6.40

Principle 3: Feature scaling (skipping the first two principles):

k-Means calculates the distance from each point to a cluster center using Euclidean distance: the length of the straight line segment connecting the two points. In two dimensions, this is just the Pythagorean Theorem. Aha, you say! I see the problem: we are comparing magnitude (4.0-6.4) to depth (40-680). Depth has far more variation (standard deviation of roughly 215 vs. 0.4 for magnitude) and therefore dominates the distance to the cluster center.
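
To see the imbalance concretely, take two hypothetical quakes that differ by two full magnitude points but only modestly in depth:

```r
# Euclidean distance: square root of the sum of squared differences
euclid <- function(a, b) sqrt(sum((a - b)^2))

# Hypothetical quakes as (depth in km, magnitude) pairs
q1 <- c(depth = 100, mag = 4.0)
q2 <- c(depth = 150, mag = 6.0)

euclid(q1, q2)
## sqrt(50^2 + 2^2) = sqrt(2504), about 50.04 -- the depth difference
## contributes almost everything; the magnitude difference barely registers
```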

We need to employ feature scaling. As a general rule, if we are comparing unlike units (meters and kilograms) or independent measurements (height in meters and circumference in meters), we should normalize values, but if units are related (petal length and petal width), we should leave them as is.

Unfortunately, many cases require judgment, both on whether to scale and on how to scale. This is where your expert opinion as a data analyst becomes important. For the purposes of this blog post, we will normalize all of our features, including latitude and longitude, by transforming them to standard normal distributions. Geologists might object to this methodology for normalizing (magnitude is a log scale!), but please forgive some imprecision for the sake of illustration.

# Create a tibble to store the information we need to normalize
#  Tibble with row 1 = mean and row 2 = standard deviation
transformations <- dplyr::tibble(
  lat   = c(mean(quakes_raw$lat),   sd(quakes_raw$lat)),
  long  = c(mean(quakes_raw$long),  sd(quakes_raw$long)),
  depth = c(mean(quakes_raw$depth), sd(quakes_raw$depth)),
  mag   = c(mean(quakes_raw$mag),   sd(quakes_raw$mag))
)

# Use the convenient function we wrote earlier
quakes_normalized <- quakes_raw %>% 
  dplyr::mutate(
    lat = normalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = normalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = normalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = normalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )

summary(quakes_normalized)
##       lat                long              depth              mag          
##  Min.   :-3.56890   Min.   :-2.27235   Min.   :-1.2591   Min.   :-1.54032  
##  1st Qu.:-0.56221   1st Qu.: 0.02603   1st Qu.:-0.9853   1st Qu.:-0.79548  
##  Median : 0.06816   Median : 0.32095   Median :-0.2987   Median :-0.05065  
##  Mean   : 0.00000   Mean   : 0.00000   Mean   : 0.0000   Mean   : 0.00000  
##  3rd Qu.: 0.59761   3rd Qu.: 0.61586   3rd Qu.: 1.0747   3rd Qu.: 0.69419  
##  Max.   : 1.97319   Max.   : 1.42812   Max.   : 1.7103   Max.   : 4.41837

With our fully preprocessed data, let's re-run our k-Means analysis, this time in four dimensions:

kclust <- kmeans(quakes_normalized, centers = 4, iter.max = 10, nstart = 5)
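
A note on the arguments: because random initialization can land in a poor local optimum, `nstart = 5` runs the whole algorithm five times from different random starts and keeps the fit with the lowest total within-cluster sum of squares. A quick demonstration of the idea on a built-in dataset (iris, used here purely for illustration):

```r
set.seed(2021)
one_start  <- kmeans(scale(iris[, 1:4]), centers = 3, nstart = 1)
many_start <- kmeans(scale(iris[, 1:4]), centers = 3, nstart = 25)

# The best-of-25 fit is (virtually) never worse than a single random start
many_start$tot.withinss <= one_start$tot.withinss
```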

str(kclust)
## List of 9
##  $ cluster     : int [1:1000] 1 1 2 1 1 3 4 2 2 1 ...
##  $ centers     : num [1:4, 1:4] -0.012 -1.736 0.294 0.934 0.222 ...
##   ..- attr(*, "dimnames")=List of 2
##   .. ..$ : chr [1:4] "1" "2" "3" "4"
##   .. ..$ : chr [1:4] "lat" "long" "depth" "mag"
##  $ totss       : num 3996
##  $ withinss    : num [1:4] 594 253 340 358
##  $ tot.withinss: num 1546
##  $ betweenss   : num 2450
##  $ size        : int [1:4] 420 143 242 195
##  $ iter        : int 4
##  $ ifault      : int 0
##  - attr(*, "class")= chr "kmeans"

Print the full model output:

kclust
## K-means clustering with 4 clusters of sizes 420, 143, 242, 195
## 
## Cluster means:
##           lat       long      depth         mag
## 1 -0.01202836  0.2224322  1.0714971 -0.25340992
## 2 -1.73556818  0.3865266 -0.8379212  0.32263747
## 3  0.29398939  0.8951082 -0.7466850 -0.07834941
## 4  0.93380886 -1.8733897 -0.7667092  0.40643880
## 
## Clustering vector:
##    [1] 1 1 2 1 1 3 4 2 2 1 1 4 1 1 4 3 4 1 1 1 1 4 1 2 1 1 4 1 1 3 1 4 3 1 3 1 4
##   [38] 1 1 4 2 3 1 3 4 2 2 4 1 3 1 3 4 1 1 1 1 1 1 1 1 3 1 4 1 3 1 1 1 3 3 3 4 1
##   [75] 1 1 3 4 1 2 2 1 1 3 1 3 4 1 3 3 4 4 1 4 3 1 2 3 4 1 3 1 1 2 1 3 2 4 2 2 3
##  [112] 1 1 1 1 1 4 4 4 4 4 3 1 1 2 4 1 1 2 3 3 1 4 1 1 4 3 3 2 3 4 1 4 2 1 3 3 4
##  [149] 1 1 2 4 3 4 4 1 4 1 4 4 1 1 4 2 2 2 2 3 1 4 1 1 1 3 1 2 1 3 1 3 1 3 3 1 1
##  [186] 3 1 1 1 3 3 4 1 1 2 1 3 3 1 1 1 1 3 2 4 3 1 1 2 1 2 3 2 3 1 1 2 1 3 1 3 4
##  [223] 3 1 1 4 3 3 2 4 1 1 1 3 1 1 2 1 4 1 3 3 4 3 3 1 1 3 1 4 4 4 1 4 2 1 3 4 1
##  [260] 1 3 1 4 1 3 2 4 4 1 1 3 1 1 1 1 1 1 1 1 1 2 1 3 1 3 1 1 3 1 1 1 4 1 1 3 3
##  [297] 1 1 2 4 1 3 3 1 1 1 1 1 1 2 4 4 1 1 3 1 1 4 1 4 4 3 1 2 4 2 1 4 3 4 3 1 2
##  [334] 4 1 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 4 1 2 2 1 4 3 2 4 1 3 1 2 4 2 1 3 1 3
##  [371] 1 3 1 1 3 3 1 3 3 3 4 4 2 4 1 2 1 1 4 2 3 2 1 1 1 1 1 3 1 4 2 4 2 1 3 1 1
##  [408] 4 1 2 3 3 4 1 3 4 4 2 2 3 1 2 3 1 2 2 1 1 4 1 1 1 1 1 3 2 4 1 1 3 4 3 4 1
##  [445] 3 3 3 1 1 1 2 1 4 3 3 4 1 3 1 1 3 1 1 1 2 1 4 3 1 2 3 3 3 4 3 2 2 2 1 3 1
##  [482] 3 3 2 1 2 2 1 1 1 2 4 1 1 3 4 3 2 1 3 3 3 3 4 1 1 4 4 4 1 2 3 3 1 1 3 4 3
##  [519] 3 1 3 1 2 4 2 1 4 4 1 2 4 4 3 4 3 4 1 4 4 1 4 4 4 4 4 4 4 1 3 1 1 4 4 3 3
##  [556] 3 3 2 1 4 1 3 2 3 1 3 2 1 3 2 4 3 1 3 1 3 1 1 1 2 4 1 4 1 1 3 4 1 1 1 1 1
##  [593] 4 4 3 4 4 1 3 3 2 2 1 1 1 2 3 3 1 2 2 4 4 1 3 1 3 4 1 4 2 2 1 1 4 1 2 1 4
##  [630] 1 1 2 2 1 3 1 3 4 3 1 1 4 2 3 1 3 2 3 2 1 1 1 4 1 1 4 1 2 1 1 2 1 1 1 1 1
##  [667] 1 3 1 3 1 4 1 1 4 2 3 1 1 1 4 3 1 1 3 4 1 1 1 1 1 1 1 1 1 1 1 3 4 1 4 1 1
##  [704] 1 3 3 1 3 3 1 3 4 3 4 1 1 1 2 3 1 4 1 1 1 4 4 1 1 1 1 1 2 2 1 4 3 3 1 1 1
##  [741] 1 3 3 2 3 2 2 3 1 1 1 3 1 4 1 4 3 3 4 3 3 1 3 3 4 4 3 3 4 3 3 1 3 1 1 1 3
##  [778] 1 4 1 4 4 2 3 3 2 4 1 3 1 1 1 1 2 1 2 4 1 3 1 2 1 3 1 1 3 1 1 1 1 1 3 4 3
##  [815] 1 1 3 3 1 1 1 3 1 1 4 1 4 3 3 3 2 3 3 1 1 3 4 1 3 4 1 3 2 4 1 1 1 1 1 1 3
##  [852] 4 4 3 1 3 1 3 2 4 1 1 1 1 4 4 1 3 4 4 1 4 4 1 1 1 1 1 1 1 3 1 4 4 2 1 4 4
##  [889] 3 2 4 1 4 3 1 1 3 1 3 1 2 3 2 1 1 1 1 4 4 2 1 4 3 1 4 4 3 3 3 1 4 3 1 1 2
##  [926] 1 4 2 2 4 3 1 1 3 3 3 1 3 1 1 1 2 1 2 3 2 1 2 3 3 1 2 3 1 1 1 2 3 1 3 3 4
##  [963] 4 3 3 1 1 1 4 2 2 2 3 4 3 3 3 2 1 4 1 4 2 2 3 4 4 4 1 3 1 1 3 1 3 1 4 3 3
## [1000] 4
## 
## Within cluster sum of squares by cluster:
## [1] 594.2867 252.9391 340.4248 358.4645
##  (between_SS / total_SS =  61.3 %)
## 
## Available components:
## 
## [1] "cluster"      "centers"      "totss"        "withinss"     "tot.withinss"
## [6] "betweenss"    "size"         "iter"         "ifault"
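
The between_SS / total_SS figure in the printout is just the ratio of two of these components, expressed as a percentage. Roughly speaking, it is the share of the total variance captured by the cluster structure; higher means tighter clusters:

```r
# Reproduce between_SS / total_SS from the components printed above
betweenss <- 2450  # kclust$betweenss
totss     <- 3996  # kclust$totss
round(100 * betweenss / totss, 1)
## 61.3
```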

Denormalize the data and prepare for plotting:

point_assignments <- broom::augment(kclust, quakes_normalized) %>% 
  dplyr::mutate(
    lat = unnormalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = unnormalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = unnormalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = unnormalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )

cluster_info <- broom::tidy(kclust) %>% 
  dplyr::mutate(
    lat = unnormalize_values(
      lat, transformations$lat[1], transformations$lat[2]
    ),
    long = unnormalize_values(
      long, transformations$long[1], transformations$long[2]
    ),
    depth = unnormalize_values(
      depth, transformations$depth[1], transformations$depth[2]
    ),
    mag = unnormalize_values(
      mag, transformations$mag[1], transformations$mag[2]
    )
  )

model_stats <- broom::glance(kclust)

head(point_assignments)
## # A tibble: 6 x 5
##     lat  long depth   mag .cluster
##   <dbl> <dbl> <dbl> <dbl> <fct>   
## 1 -20.4  182.   562   4.8 1       
## 2 -20.6  181.   650   4.2 1       
## 3 -26    184.    42   5.4 2       
## 4 -18.0  182.   626   4.1 1       
## 5 -20.4  182.   649   4   1       
## 6 -19.7  184.   195   4   3

Model statistics:

model_stats
## # A tibble: 1 x 4
##   totss tot.withinss betweenss  iter
##   <dbl>        <dbl>     <dbl> <int>
## 1  3996        1546.     2450.     4

Cluster information:

cluster_info
## # A tibble: 4 x 7
##     lat  long depth   mag  size withinss cluster
##   <dbl> <dbl> <dbl> <dbl> <int>    <dbl> <fct>  
## 1 -20.7  181.  542.  4.52   420     594. 1      
## 2 -29.4  182.  131.  4.75   143     253. 2      
## 3 -19.2  185.  150.  4.59   242     340. 3      
## 4 -15.9  168.  146.  4.78   195     358. 4

Plot the data with clusters:

plotly::plot_ly() %>% 
  plotly::add_trace(
    data = point_assignments,
    x = ~long, y = ~lat, z = ~depth*-1, size = ~mag,
    color = ~.cluster,
    type = "scatter3d", mode = "markers",
    marker = list(symbol = "circle", sizemode = "diameter"),
    sizes = c(5, 30)
  ) %>% 
  plotly::layout(scene = list(
    xaxis = list(title = "Longitude"),
    yaxis = list(title = "Latitude"),
    zaxis = list(title = "Depth")
  ))
## Warning: `line.width` does not currently support multiple values.
